/*
  Copyright 2006-2012 Patrik Jonsson, sunrise@familjenjonsson.org

  This file is part of Sunrise.

  Sunrise is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 3 of the License, or
  (at your option) any later version.

  Sunrise is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.
  
*/

#include "almost_equal.h"
#include <algorithm>
#include <cmath>
#include <fstream>
#include<iostream>
#include <iterator>
#include "CCfits/CCfits"
#include "counter.h"
#include "polychromatic_grid.h"
#include "optical.h"
#include "scatterer.h"
#include "constants.h"
#include "dust_model.h"
#include "emission.h"
#include "emission-fits.h"
#include "full_sed_emergence.h"
#include "full_sed_grid.h"
#include "emergence-fits.h"
#include "mcrx-units.h"
#include "grid_factory.h"
#include "grid-fits.h"
#include "shoot.h"
#include "fits-utilities.h"
#include "boost/shared_ptr.hpp"
#include "density_generator.h"
#include "preferences.h"
#include "model.h"
#include "wd01_Brent_PAH_grain_model.h"
#include "wd01_grain_model.h"
#include "xfer.h"
#include "misc.h"

using namespace mcrx;
using namespace blitz;
using namespace CCfits;
using namespace std; 

// Random-number policy shared by the scatterer and emission types.
typedef local_random T_rng_policy;

// Polychromatic scattering and dust-model types.
typedef scatterer<polychromatic_scatterer_policy, T_rng_policy> T_scatterer;
typedef T_scatterer::T_biaser T_biaser;
typedef T_scatterer::T_lambda T_lambda;
typedef T_polychromatic_dust_model T_dust_model;

// Emission, grid and emergence (camera) types used by the transfer.
typedef emission<polychromatic_policy, T_rng_policy> T_emission;
typedef full_sed_grid<adaptive_grid<absorber<array_1> > > T_grid;
typedef full_sed_emergence T_emergence;

// The xfer and full_sed_grid instantiations for these types already
// live in libmcrx; declaring them extern keeps this file from
// instantiating them again (lowers the -O0 compile time from 6min to
// about 2min).
// NOTE(review): unlike the xfer declaration, the full_sed_grid one is
// not namespace-qualified. If full_sed_grid is declared in namespace
// mcrx, strict compilers may require mcrx:: here - confirm.
extern template class mcrx::xfer<T_dust_model, T_grid>;
extern template class full_sed_grid<adaptive_grid<absorber<array_1> > >;

class uniform_factory {
public:
  typedef typename T_grid::T_data T_data;
  typedef refinement_accuracy_data<T_data> T_racc;
  typedef cell_tracker<T_data> T_cell_tracker;

  const T_densities rho_;
  const int n_lambda;
  const int n_threads_;

  uniform_factory (const T_densities& rho, int nl, int nt):
    rho_ (rho), n_lambda (nl), n_threads_(nt) {};
  
  bool refine_cell_p (const T_cell_tracker& c) {
    return false;
  };
  bool unify_cell_p (const T_cell_tracker& c, const T_racc& racc) {
    return false;
  };
  T_data get_data (const T_cell_tracker& c) {
    return T_data(rho_, vec3d(0,0,0), T_lambda(n_lambda));
  }
    
  virtual int n_threads () {return n_threads_;};
};


/** Run one configuration of the Watson & Henney (2001) scattering test.

    Builds a uniform slab spanning (-1000,-1000,0)-(1000,1000,1) (length
    unit kpc) with density 2 of a single simple_HG_dust_grain at one
    wavelength. It places the emitter at the origin, shoots 1e6 rays
    with 12 threads, and records the result in cameras at polar angles
    0, 10, ..., 180 degrees at distance 10000.

    \param minscat Minimum scattering order, passed to xfer.
    \param maxscat Maximum scattering order, passed to xfer.
    \param seed    Seed passed to mcrx::seed().
    \param laser   If true, emit with a laser_emission pointing along +z;
                   otherwise use a pointsource_emission.
    \return One value per camera: the image summed after multiplying by
            the pixel solid angles, times distance^2. */
vector<T_float> watson_henney_test (int minscat, int maxscat,
				    int seed, bool laser)
{
  const T_float rho=2;
  const T_float emitter_z = 0;
  const vec3d grid_min(-1000,-1000,0);
  const vec3d grid_max(1000,1000,1);
  const T_float distance = 10000;
  const int n_threads=12;
  const int n_lambda=1;
  const int n_rays=1000000;

  T_lambda lambda(n_lambda); lambda=1;

  mcrx::seed(seed);
  T_unit_map units;
  units ["length"] = "kpc";
  units ["mass"] = "Msun";
  units ["luminosity"] = "W";
  units ["wavelength"] = "m";
  units ["L_lambda"] = "W/m";

  // use a simple HG grain here since we just want a single wavelength
  // and no IR emission.
  T_lambda kk(n_lambda); kk=1;
  T_lambda aa(n_lambda); aa=0.5;
  T_lambda gg(n_lambda); gg=0.5;
  T_dust_model::T_scatterer_vector sv;
  sv.push_back(boost::shared_ptr<T_scatterer>
	       (new simple_HG_dust_grain<polychromatic_scatterer_policy, mcrx_rng_policy> (kk,gg,aa)));
  T_dust_model model (sv);
  model.set_wavelength(lambda);

  T_densities rhov(n_lambda); rhov=rho;
  uniform_factory factory (rhov, n_lambda, n_threads);
  // first create the adaptive grid
  boost::shared_ptr<adaptive_grid<T_grid::T_data> > tempgrid
    (new adaptive_grid<T_grid::T_data>
     (grid_min, grid_max, factory));

  // and then the full_sed_grid
  T_grid gr(tempgrid, units);

  cout << "setting up emission" << endl;
  // boost::shared_ptr (already used above) replaces std::auto_ptr,
  // which is deprecated and removed in C++17. reset(Y*) also deletes
  // through the concrete emission type.
  boost::shared_ptr<T_emission> em;
  if(laser)
    em.reset(new laser_emission<polychromatic_policy, T_rng_policy>
	     (vec3d (0., 0., emitter_z), vec3d(0,0,1), lambda));
  else
    em.reset(new pointsource_emission<polychromatic_policy, T_rng_policy>
	     (vec3d (0., 0., emitter_z), lambda));

  // Camera polar angles 0, 10, ..., 180 degrees. Looping over integer
  // degrees makes the 90-degree special case an exact comparison.
  vector<pair<T_float, T_float> > campos;
  for (int theta_deg = 0; theta_deg <= 180; theta_deg += 10) {
    // because their theta=90 is on the half-plane where the source is
    // visible, we use a slightly offset theta
    if (theta_deg == 90)
      campos.push_back(make_pair(0.50000000001*constants::pi,0.));
    else
      campos.push_back(make_pair(theta_deg*constants::pi/180,0.));
  }

  boost::shared_ptr<T_emergence> e(new T_emergence(campos, distance, 0.1*distance, 1));

  std::vector<T_rng::T_state> states= generate_random_states (n_threads);
  T_rng_policy rng;

  T_biaser b(0);
  xfer<T_dust_model, T_grid> x (gr, *em, *e, model, rng, b, 1e-2, 100, false, false, minscat, maxscat);
  long n_rays_actual = 0;
  dummy_terminator t;
  shoot (x, scatter_shooter(), t, states, n_rays,
	 n_rays_actual, n_threads, true, n_threads,0, "");

  // convert images to flux by multiplying by the solid angle
  // subtended by the pixels.
  vector<T_float> result;
  result.reserve(campos.size());
  for (T_emergence::iterator i = e->begin(); i != e->end(); ++i) {
    const T_emergence::T_camera::T_image image =
      (*i)->get_image ();
    array_3 temp(image*(*i)->pixel_solid_angle_image()(tensor::i, tensor::j));
    const T_float scatter = sum (temp)*distance*distance;
    result.push_back(scatter);
  }
  return result;
}

/** Make comparison. res is the MC output, theo is the theoretical
    results. ncol is number of columns to compare, roffset is the
    offset into the results where the comparison starts, toffset the
    column offset into the theoretical results. */
bool print_test(const vector<vector<T_float> >& res, const array_2& theo,
		int ncol, int roffset, int toffset, bool comp)
{
  T_float maxdiff=0;
  bool pass=true;
  if(comp) cout << "Fractional difference against WH01:\n";
  for(int i=0; i<res[0].size();++i) {
    for(int j=0; j<ncol;++j) {
      T_float fracdiff= res[j+roffset][i]/theo(i,j+toffset)-1;
      maxdiff=max(maxdiff, abs(fracdiff));
      bool pass_here = almostEqual(res[j+roffset][i],theo(i,j+toffset),10000L) |
	(fracdiff<1e-2);
      pass &= pass_here;

      cout << 
	(comp ? fracdiff : res[j+roffset][i] )
	   << '\t';
      if(comp && !pass_here) cout << " (fail) ";
    }
    cout << endl;
  }
  cout << endl;
  if(comp) {
    cout << "Max fractional difference: " << maxdiff << endl;
    if(!pass)
      cout << "TEST FAILED!\n";
    else
      cout << "Test passed.\n";
    return pass;
  }
  return true;
}

/** Run the Watson & Henney (2001) scattered-intensity comparison.

    Reads the reference table from
    $HOME/tests_sunrise/wh01/wh01results.txt. It must hold 19 rows of 9
    values, read row-major. The function then runs watson_henney_test
    for a point source and for a pencil beam (laser) at scattering
    orders 0, 1, 2 and 3-or-more, and compares both sets against the
    table.

    \return true if both comparisons pass. false if either fails, or if
            the reference file cannot be opened or has the wrong number
            of values. */
bool test_scattered_intensity ()
{
  cout << "Running Watson & Henney 2001 scattered intensity test\n\n";

  // Shape of the reference table: one row per camera angle and 9
  // columns. Column 0 is not compared (presumably the angle), 1-4 are
  // the point-source results and 5-8 the pencil-beam results.
  const int n_angles = 19;
  const int n_columns = 9;

  const string whfn("$HOME/tests_sunrise/wh01/wh01results.txt");
  ifstream whf(word_expand(whfn)[0].c_str());
  if(!whf) {
    cerr << "Could not open file with Watson-Henney results: " << whfn << endl;
    // A missing reference file is a failure; returning 1 here would
    // convert to true and report a pass.
    return false;
  }
  vector<T_float> whd((istream_iterator<T_float>(whf)),
		      istream_iterator<T_float>());
  // Runtime check (not just assert): array_2 below reads
  // n_angles*n_columns values from whd unconditionally.
  if (whd.size() != static_cast<size_t>(n_angles*n_columns)) {
    cerr << "Unexpected number of values in Watson-Henney results file "
	 << whfn << ": got " << whd.size() << ", expected "
	 << n_angles*n_columns << endl;
    return false;
  }
  array_2 whdata(&whd[0], shape(n_angles, n_columns), neverDeleteData);

  // res[0..3]: point source, res[4..7]: pencil beam. Within each group
  // the scattering orders are 0, 1, 2 and 3-or-more.
  vector<vector<T_float> > res;
  for (int beam = 0; beam < 2; ++beam) {
    const bool laser = (beam == 1);
    res.push_back(watson_henney_test (0,0, 42, laser));
    res.push_back(watson_henney_test (1,1, 42, laser));
    res.push_back(watson_henney_test (2,2, 42, laser));
    res.push_back(watson_henney_test (3,blitz::huge(int()), 42, laser));
  }

  bool pass=true;
  cout << "\n\nPoint source test:\n";
  print_test(res,whdata,4,0,1,false);
  pass &=print_test(res,whdata,4,0,1,true);

  cout << "\n\nPencilbeam test:\n";
  print_test(res,whdata,4,4,5,false);
  pass &=print_test(res,whdata,4,4,5,true);

  return pass;
}

/** Entry point: initialises mcrx::mpienv, calls hpm::disable(), then
    runs the scattered-intensity test. Exits with 0 if the test passed
    and 1 otherwise. */
int main (int argc, char** argv)
{
  mcrx::mpienv::init(argc, argv);
  hpm::disable();

  const bool passed = test_scattered_intensity();
  return passed ? 0 : 1;
}
